{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "套路：\n",
    "- 准备数据\n",
    "- 实现算法\n",
    "- 测试算法"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true,
    "heading_collapsed": true
   },
   "source": [
    "# 任务1：亲和性分析\n",
    "- 如果一个顾客买了商品X，那么他们可能愿意买商品Y\n",
    "\n",
    "衡量方法：\n",
    "- 支持度support := 所有买X的人数\n",
    "\n",
    "- 置信度confidence := $\\frac{所有买X和Y的人数}{所有买X的人数}$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "# 引入库\n",
    "import numpy as np\n",
    "from operator import itemgetter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(100, 5)\n",
      "[[0. 0. 1. 1. 0.]\n",
      " [1. 1. 0. 0. 0.]\n",
      " [1. 0. 0. 1. 1.]\n",
      " [0. 1. 1. 0. 1.]\n",
      " [0. 1. 0. 0. 0.]]\n"
     ]
    }
   ],
   "source": [
    "# 准备数据\n",
    "\n",
    "# 创造随机生成的数据　(可跳过)\n",
    "X = np.zeros((100, 5), dtype='bool')\n",
    "for i in range(X.shape[0]):\n",
    "    if np.random.random() < 0.3:\n",
    "        # A bread winner\n",
    "        X[i][0] = 1\n",
    "        if np.random.random() < 0.5:\n",
    "            # Who likes milk\n",
    "            X[i][1] = 1\n",
    "        if np.random.random() < 0.2:\n",
    "            # Who likes cheese\n",
    "            X[i][2] = 1\n",
    "        if np.random.random() < 0.25:\n",
    "            # Who likes apples\n",
    "            X[i][3] = 1\n",
    "        if np.random.random() < 0.5:\n",
    "            # Who likes bananas\n",
    "            X[i][4] = 1\n",
    "    else:\n",
    "        # Not a bread winner\n",
    "        if np.random.random() < 0.5:\n",
    "            # Who likes milk\n",
    "            X[i][1] = 1\n",
    "            if np.random.random() < 0.2:\n",
    "                # Who likes cheese\n",
    "                X[i][2] = 1\n",
    "            if np.random.random() < 0.25:\n",
    "                # Who likes apples\n",
    "                X[i][3] = 1\n",
    "            if np.random.random() < 0.5:\n",
    "                # Who likes bananas\n",
    "                X[i][4] = 1\n",
    "        else:\n",
    "            if np.random.random() < 0.8:\n",
    "                # Who likes cheese\n",
    "                X[i][2] = 1\n",
    "            if np.random.random() < 0.6:\n",
    "                # Who likes apples\n",
    "                X[i][3] = 1\n",
    "            if np.random.random() < 0.7:\n",
    "                # Who likes bananas\n",
    "                X[i][4] = 1\n",
    "    if X[i].sum() == 0:\n",
    "        X[i][4] = 1  # Must buy something, so gets bananas\n",
    "np.savetxt(\"./data/affinity_dataset.txt\", X, fmt='%d') # 保存\n",
    "\n",
    "\n",
    "# 读取数据\n",
    "dataset_filename = \"./data/affinity_dataset.txt\"\n",
    "X = np.loadtxt(dataset_filename) # 加载数据\n",
    "n_samples, n_features = X.shape\n",
    "print(X.shape)\n",
    "print(X[:5])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们定义的规则为：买了苹果又买香蕉\n",
    "\n",
    "下面 rule_valid 表示买了苹果又买香蕉的有多少人\n",
    "\n",
    "显然支持度 := 所有买X的人数，即支持度=rule_valid\n",
    "\n",
    "所以置信度=支持度/总人数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "买苹果的有39人\n",
      "买了苹果又买香蕉的有23人\n",
      "买了苹果不买香蕉的有16人\n",
      "支持度support = 23 置信度confidence = 0.590.\n",
      "置信度confidence的百分比形式为 59.0%.\n"
     ]
    }
   ],
   "source": [
    "# 文件affinity_dataset.txt是生成的数据，得我们来指定列\n",
    "features = [\"bread\", \"milk\", \"cheese\", \"apples\", \"bananas\"]\n",
    "\n",
    "num_apple_purchases = 0 # 计数\n",
    "for sample in X:\n",
    "    if sample[3] == 1:  # 记录买 Apples　的有多少人\n",
    "        num_apple_purchases += 1\n",
    "print(\"买苹果的有{0}人\".format(num_apple_purchases))\n",
    "\n",
    "rule_valid = 0\n",
    "rule_invalid = 0\n",
    "for sample in X:\n",
    "    if sample[3] == 1:  # 买了苹果\n",
    "        if sample[4] == 1:# 又买香蕉的\n",
    "            rule_valid += 1\n",
    "        else:# 不买香蕉的\n",
    "            rule_invalid += 1\n",
    "print(\"买了苹果又买香蕉的有{0}人\".format(rule_valid))\n",
    "print(\"买了苹果不买香蕉的有{0}人\".format(rule_invalid))\n",
    "\n",
    "# 计算支持度support和置信度confidence\n",
    "support = rule_valid  # 支持度是符合“买了苹果又买香蕉”这个规则的人数\n",
    "confidence = rule_valid / num_apple_purchases\n",
    "print(\"支持度support = {0} 置信度confidence = {1:.3f}.\".format(support, confidence))\n",
    "# 置信度的百分比形式\n",
    "print(\"置信度confidence的百分比形式为 {0:.1f}%.\".format(100 * confidence))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "# 上面\"买了苹果又买香蕉\"是一种情况，现在把所有可能的情况都做一遍\n",
    "valid_rules = defaultdict(int)\n",
    "invalid_rules = defaultdict(int)\n",
    "num_occurences = defaultdict(int)\n",
    "\n",
    "for sample in X:\n",
    "    for premise in range(n_features):\n",
    "        if sample[premise] == 0: continue\n",
    "        # 先买premise，premise代表一种食物，记做X\n",
    "        num_occurences[premise] += 1\n",
    "        for conclusion in range(n_features):\n",
    "            if premise == conclusion:  \n",
    "                continue # 跳过买X又买X的情况\n",
    "            if sample[conclusion] == 1: # 又买了conclusion，conclusion代表一种食物，记做Y\n",
    "                valid_rules[(premise, conclusion)] += 1 # 买X买Y\n",
    "            else: \n",
    "                invalid_rules[(premise, conclusion)] += 1 # 买X没买Y\n",
    "support = valid_rules\n",
    "confidence = defaultdict(float)\n",
    "for premise, conclusion in valid_rules.keys():\n",
    "    confidence[(premise, conclusion)] = valid_rules[(premise, conclusion)] / num_occurences[premise]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "hidden": true,
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Rule: 买了cheese，又买apples\n",
      " - 置信度Confidence: 0.553\n",
      " - 支持度Support: 26\n",
      "\n",
      "Rule: 买了apples，又买cheese\n",
      " - 置信度Confidence: 0.667\n",
      " - 支持度Support: 26\n",
      "\n",
      "Rule: 买了bread，又买milk\n",
      " - 置信度Confidence: 0.619\n",
      " - 支持度Support: 13\n",
      "\n",
      "Rule: 买了milk，又买bread\n",
      " - 置信度Confidence: 0.265\n",
      " - 支持度Support: 13\n",
      "\n",
      "Rule: 买了bread，又买apples\n",
      " - 置信度Confidence: 0.286\n",
      " - 支持度Support: 6\n",
      "\n",
      "Rule: 买了bread，又买bananas\n",
      " - 置信度Confidence: 0.476\n",
      " - 支持度Support: 10\n",
      "\n",
      "Rule: 买了apples，又买bread\n",
      " - 置信度Confidence: 0.154\n",
      " - 支持度Support: 6\n",
      "\n",
      "Rule: 买了apples，又买bananas\n",
      " - 置信度Confidence: 0.590\n",
      " - 支持度Support: 23\n",
      "\n",
      "Rule: 买了bananas，又买bread\n",
      " - 置信度Confidence: 0.185\n",
      " - 支持度Support: 10\n",
      "\n",
      "Rule: 买了bananas，又买apples\n",
      " - 置信度Confidence: 0.426\n",
      " - 支持度Support: 23\n",
      "\n",
      "Rule: 买了milk，又买cheese\n",
      " - 置信度Confidence: 0.204\n",
      " - 支持度Support: 10\n",
      "\n",
      "Rule: 买了milk，又买bananas\n",
      " - 置信度Confidence: 0.429\n",
      " - 支持度Support: 21\n",
      "\n",
      "Rule: 买了cheese，又买milk\n",
      " - 置信度Confidence: 0.213\n",
      " - 支持度Support: 10\n",
      "\n",
      "Rule: 买了cheese，又买bananas\n",
      " - 置信度Confidence: 0.532\n",
      " - 支持度Support: 25\n",
      "\n",
      "Rule: 买了bananas，又买milk\n",
      " - 置信度Confidence: 0.389\n",
      " - 支持度Support: 21\n",
      "\n",
      "Rule: 买了bananas，又买cheese\n",
      " - 置信度Confidence: 0.463\n",
      " - 支持度Support: 25\n",
      "\n",
      "Rule: 买了bread，又买cheese\n",
      " - 置信度Confidence: 0.238\n",
      " - 支持度Support: 5\n",
      "\n",
      "Rule: 买了cheese，又买bread\n",
      " - 置信度Confidence: 0.106\n",
      " - 支持度Support: 5\n",
      "\n",
      "Rule: 买了milk，又买apples\n",
      " - 置信度Confidence: 0.184\n",
      " - 支持度Support: 9\n",
      "\n",
      "Rule: 买了apples，又买milk\n",
      " - 置信度Confidence: 0.231\n",
      " - 支持度Support: 9\n",
      "\n"
     ]
    }
   ],
   "source": [
    "for premise, conclusion in confidence:\n",
    "    premise_name = features[premise]\n",
    "    conclusion_name = features[conclusion]\n",
    "    print(\"Rule: 买了{0}，又买{1}\".format(premise_name, conclusion_name))\n",
    "    print(\" - 置信度Confidence: {0:.3f}\".format(confidence[(premise, conclusion)]))\n",
    "    print(\" - 支持度Support: {0}\".format(support[(premise, conclusion)]))\n",
    "    print(\"\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Rule: 买了milk，又买apples\n",
      " - 置信度Confidence: 0.184\n",
      " - 支持度Support: 9\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# 封装一下方便调用\n",
    "def print_rule(premise, conclusion, support, confidence, features):\n",
    "    premise_name = features[premise]\n",
    "    conclusion_name = features[conclusion]\n",
    "    print(\"Rule: 买了{0}，又买{1}\".format(premise_name, conclusion_name))\n",
    "    print(\" - 置信度Confidence: {0:.3f}\".format(confidence[(premise, conclusion)]))\n",
    "    print(\" - 支持度Support: {0}\".format(support[(premise, conclusion)]))\n",
    "    print(\"\")\n",
    "    \n",
    "premise = 1\n",
    "conclusion = 3\n",
    "print_rule(premise, conclusion, support, confidence, features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "hidden": true,
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[((2, 3), 26),\n",
      " ((3, 2), 26),\n",
      " ((0, 1), 13),\n",
      " ((1, 0), 13),\n",
      " ((0, 3), 6),\n",
      " ((0, 4), 10),\n",
      " ((3, 0), 6),\n",
      " ((3, 4), 23),\n",
      " ((4, 0), 10),\n",
      " ((4, 3), 23),\n",
      " ((1, 2), 10),\n",
      " ((1, 4), 21),\n",
      " ((2, 1), 10),\n",
      " ((2, 4), 25),\n",
      " ((4, 1), 21),\n",
      " ((4, 2), 25),\n",
      " ((0, 2), 5),\n",
      " ((2, 0), 5),\n",
      " ((1, 3), 9),\n",
      " ((3, 1), 9)]\n"
     ]
    }
   ],
   "source": [
    "# 按支持度support排序\n",
    "from pprint import pprint\n",
    "pprint(list(support.items()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Rule #1\n",
      "Rule: 买了apples，又买cheese\n",
      " - 置信度Confidence: 0.667\n",
      " - 支持度Support: 26\n",
      "\n",
      "Rule #2\n",
      "Rule: 买了bread，又买milk\n",
      " - 置信度Confidence: 0.619\n",
      " - 支持度Support: 13\n",
      "\n",
      "Rule #3\n",
      "Rule: 买了apples，又买bananas\n",
      " - 置信度Confidence: 0.590\n",
      " - 支持度Support: 23\n",
      "\n",
      "Rule #4\n",
      "Rule: 买了cheese，又买apples\n",
      " - 置信度Confidence: 0.553\n",
      " - 支持度Support: 26\n",
      "\n",
      "Rule #5\n",
      "Rule: 买了cheese，又买bananas\n",
      " - 置信度Confidence: 0.532\n",
      " - 支持度Support: 25\n",
      "\n"
     ]
    }
   ],
   "source": [
    "sorted_confidence = sorted(confidence.items(), key=itemgetter(1), reverse=True)\n",
    "for index in range(5): # 打印前5个\n",
    "    print(\"Rule #{0}\".format(index + 1))\n",
    "    (premise, conclusion) = sorted_confidence[index][0]\n",
    "    print_rule(premise, conclusion, support, confidence, features)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "# 任务2：Iris植物分类\n",
    "- 给出某一植物部分特征，预测该植物属于哪一类\n",
    "\n",
    "特征：\n",
    "- 萼片长宽sepal width, sepal height\n",
    "- 花瓣长宽petal width, petal height\n",
    "\n",
    "算法：\n",
    "* For each variable\n",
    "    * For each value of the variable\n",
    "        * The prediction based on this variable goes the most frequent class\n",
    "        * Compute the error of this prediction\n",
    "    * Sum the prediction errors for all values of the variable\n",
    "* Use the variable with the lowest error\n",
    "\n",
    "中文算法：\n",
    "* For 给定的每个特征\n",
    "    * For 该特征对应的真值（即植物是哪一类）\n",
    "        * 预测值：基于该特征预测的次数最多的类，即在所有样本里该特征 10 次有 6 次预测了 A 类，那我们对所有样本都预测为 A 类\n",
    "        * 计算预测值与真值的误差\n",
    "    * 对上面计算的误差求和\n",
    "* 使用误差最小的特征作为最终模型\n",
    "\n",
    "很显然，“在所有样本里该特征 10 次有 6 次预测了 A 类，那我们对所有样本都预测为 A 类”是基于大数据的\n",
    "\n",
    "不过这样的规则过于简单，下面继续实验会发现准确率只有 60% 左右，当然还是比随机预测 50% 好！"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iris Plants Database\n",
      "====================\n",
      "\n",
      "Notes\n",
      "-----\n",
      "Data Set Characteristics:\n",
      "    :Number of Instances: 150 (50 in each of three classes)\n",
      "    :Number of Attributes: 4 numeric, predictive attributes and the class\n",
      "    :Attribute Information:\n",
      "        - sepal length in cm\n",
      "        - sepal width in cm\n",
      "        - petal length in cm\n",
      "        - petal width in cm\n",
      "        - class:\n",
      "                - Iris-Setosa\n",
      "                - Iris-Versicolour\n",
      "                - Iris-Virginica\n",
      "    :Summary Statistics:\n",
      "\n",
      "    ============== ==== ==== ======= ===== ====================\n",
      "                    Min  Max   Mean    SD   Class Correlation\n",
      "    ============== ==== ==== ======= ===== ====================\n",
      "    sepal length:   4.3  7.9   5.84   0.83    0.7826\n",
      "    sepal width:    2.0  4.4   3.05   0.43   -0.4194\n",
      "    petal length:   1.0  6.9   3.76   1.76    0.9490  (high!)\n",
      "    petal width:    0.1  2.5   1.20  0.76     0.9565  (high!)\n",
      "    ============== ==== ==== ======= ===== ====================\n",
      "\n",
      "    :Missing Attribute Values: None\n",
      "    :Class Distribution: 33.3% for each of 3 classes.\n",
      "    :Creator: R.A. Fisher\n",
      "    :Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)\n",
      "    :Date: July, 1988\n",
      "\n",
      "This is a copy of UCI ML iris datasets.\n",
      "http://archive.ics.uci.edu/ml/datasets/Iris\n",
      "\n",
      "The famous Iris database, first used by Sir R.A Fisher\n",
      "\n",
      "This is perhaps the best known database to be found in the\n",
      "pattern recognition literature.  Fisher's paper is a classic in the field and\n",
      "is referenced frequently to this day.  (See Duda & Hart, for example.)  The\n",
      "data set contains 3 classes of 50 instances each, where each class refers to a\n",
      "type of iris plant.  One class is linearly separable from the other 2; the\n",
      "latter are NOT linearly separable from each other.\n",
      "\n",
      "References\n",
      "----------\n",
      "   - Fisher,R.A. \"The use of multiple measurements in taxonomic problems\"\n",
      "     Annual Eugenics, 7, Part II, 179-188 (1936); also in \"Contributions to\n",
      "     Mathematical Statistics\" (John Wiley, NY, 1950).\n",
      "   - Duda,R.O., & Hart,P.E. (1973) Pattern Classification and Scene Analysis.\n",
      "     (Q327.D83) John Wiley & Sons.  ISBN 0-471-22361-1.  See page 218.\n",
      "   - Dasarathy, B.V. (1980) \"Nosing Around the Neighborhood: A New System\n",
      "     Structure and Classification Rule for Recognition in Partially Exposed\n",
      "     Environments\".  IEEE Transactions on Pattern Analysis and Machine\n",
      "     Intelligence, Vol. PAMI-2, No. 1, 67-71.\n",
      "   - Gates, G.W. (1972) \"The Reduced Nearest Neighbor Rule\".  IEEE Transactions\n",
      "     on Information Theory, May 1972, 431-433.\n",
      "   - See also: 1988 MLC Proceedings, 54-64.  Cheeseman et al\"s AUTOCLASS II\n",
      "     conceptual clustering system finds 3 classes in the data.\n",
      "   - Many, many more ...\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from sklearn.datasets import load_iris\n",
    "#X, y = np.loadtxt(\"X_classification.txt\"), np.loadtxt(\"y_classification.txt\") # 本地加载数据，我先下载好在 data 文件夹里了\n",
    "dataset = load_iris() # 或者自己亲自下载数据再加载也行\n",
    "X = dataset.data\n",
    "y = dataset.target\n",
    "print(dataset.DESCR) # 打印下数据集介绍\n",
    "n_samples, n_features = X.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "# Compute the mean for each attribute计算平均值\n",
    "attribute_means = X.mean(axis=0)\n",
    "assert attribute_means.shape == (n_features,)\n",
    "X_d = np.array(X >= attribute_means, dtype='int')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "训练集数据有 (112,) 条\n",
      "测试集数据有 (38,) 条\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/lxy/.local/lib/python3.6/site-packages/sklearn/cross_validation.py:41: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. Also note that the interface of the new CV iterators are different from that of this module. This module will be removed in 0.20.\n",
      "  \"This module will be removed in 0.20.\", DeprecationWarning)\n"
     ]
    }
   ],
   "source": [
    "# 划分训练集和测试集\n",
    "from sklearn.cross_validation import train_test_split\n",
    "\n",
    "# 设置随机数种子以便复现书里的内容\n",
    "random_state = 14\n",
    "\n",
    "X_train, X_test, y_train, y_test = train_test_split(X_d, y, random_state=random_state)\n",
    "print(\"训练集数据有 {} 条\".format(y_train.shape))\n",
    "print(\"测试集数据有 {} 条\".format(y_test.shape))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "from operator import itemgetter\n",
    "\n",
    "\n",
    "def train(X, y_true, feature):\n",
    "    \"\"\"Computes the predictors and error for a given feature using the OneR algorithm\n",
    "    \n",
    "    Parameters\n",
    "    ----------\n",
    "    X: array [n_samples, n_features]\n",
    "        The two dimensional array that holds the dataset. Each row is a sample, each column\n",
    "        is a feature.\n",
    "    \n",
    "    y_true: array [n_samples,]\n",
    "        The one dimensional array that holds the class values. Corresponds to X, such that\n",
    "        y_true[i] is the class value for sample X[i].\n",
    "    \n",
    "    feature: int\n",
    "        An integer corresponding to the index of the variable we wish to test.\n",
    "        0 <= variable < n_features\n",
    "        \n",
    "    Returns\n",
    "    -------\n",
    "    predictors: dictionary of tuples: (value, prediction)\n",
    "        For each item in the array, if the variable has a given value, make the given prediction.\n",
    "    \n",
    "    error: float\n",
    "        The ratio of training data that this rule incorrectly predicts.\n",
    "    \"\"\"\n",
    "    # 1.一些等下要用的变量（数据的形状如上）\n",
    "    n_samples, n_features = X.shape\n",
    "    assert 0 <= feature < n_features\n",
    "    values = set(X[:,feature])\n",
    "    predictors = dict()\n",
    "    errors = []\n",
    "    \n",
    "    # 2.算法（对照上面的算法流程）\n",
    "    # 已经给定特征 feature，作为函数参数传过来了\n",
    "    for current_value in values: \n",
    "    # For 该特征对应的真值（即植物是哪一类）\n",
    "    \n",
    "        most_frequent_class, error = train_feature_value(X, y_true, feature, current_value) \n",
    "        # 预测值：基于该特征预测的次数最多的类，即在所有样本里该特征 10 次有 6 次预测了 A 类，那我们对所有样本都预测为 A 类\n",
    "        \n",
    "        predictors[current_value] = most_frequent_class\n",
    "        errors.append(error)\n",
    "        # 计算预测值与真值的误差\n",
    "    \n",
    "    total_error = sum(errors)\n",
    "    # 对上面计算的误差求和\n",
    "    # python里求和函数 sum([1, 2, 3]) == 1 + 2 + 3 == 6\n",
    "    \n",
    "    return predictors, total_error\n",
    "\n",
    "# Compute what our predictors say each sample is based on its value\n",
    "#y_predicted = np.array([predictors[sample[feature]] for sample in X])\n",
    "    \n",
    "\n",
    "def train_feature_value(X, y_true, feature, value):\n",
    "    # 预测值：基于该特征预测的次数最多的类，即在所有样本里该特征 10 次有 6 次预测了 A 类，那我们对所有样本都预测为 A 类\n",
    "    # 我们需要一个字典型变量存每个变量预测正确的次数\n",
    "    class_counts = defaultdict(int)\n",
    "    # 对每个二元组（类别，真值）迭代计数\n",
    "    for sample, y in zip(X, y_true):\n",
    "        if sample[feature] == value:\n",
    "            class_counts[y] += 1\n",
    "    # 现在选被预测最多的类别，需要排序。（我们认为被预测最多的类别就是正确的）\n",
    "    sorted_class_counts = sorted(class_counts.items(), key=itemgetter(1), reverse=True)\n",
    "    most_frequent_class = sorted_class_counts[0][0]\n",
    "    # 误差定义为分类“错误”的次数，这里“错误”指样本中没有分类为我们预测的值，即样本的真实类别不是“被预测最多的类别”\n",
    "    n_samples = X.shape[1]\n",
    "    error = sum([class_count for class_value, class_count in class_counts.items()\n",
    "                 if class_value != most_frequent_class])\n",
    "    return most_frequent_class, error"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "最佳模型基于第 2 个变量，误差为 37.00\n",
      "{'variable': 2, 'predictor': {0: 0, 1: 2}}\n"
     ]
    }
   ],
   "source": [
    "# For 给定的每个特征，计算所有预测值（这里 for 写到 list 里面是 python 的语法糖）\n",
    "all_predictors = {variable: train(X_train, y_train, variable) for variable in range(X_train.shape[1])}\n",
    "errors = {variable: error for variable, (mapping, error) in all_predictors.items()}\n",
    "# 现在选择最佳模型并保存为 \"model\"\n",
    "# 按误差排序\n",
    "best_variable, best_error = sorted(errors.items(), key=itemgetter(1))[0]\n",
    "print(\"最佳模型基于第 {0} 个变量，误差为 {1:.2f}\".format(best_variable, best_error))\n",
    "\n",
    "# 选最好的模型，也就是误差最小的模型\n",
    "model = {'variable': best_variable,\n",
    "         'predictor': all_predictors[best_variable][0]}\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "def predict(X_test, model):\n",
    "    variable = model['variable']\n",
    "    predictor = model['predictor']\n",
    "    y_predicted = np.array([predictor[int(sample[variable])] for sample in X_test])\n",
    "    return y_predicted"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0 0 0 2 2 2 0 2 0 2 2 0 2 2 0 2 0 2 2 2 0 0 0 2 0 2 0 2 2 0 0 0 2 0 2 0 2\n",
      " 2]\n",
      "在测试集上的准确率 65.8%\n",
      "             precision    recall  f1-score   support\n",
      "\n",
      "          0       0.94      1.00      0.97        17\n",
      "          1       0.00      0.00      0.00        13\n",
      "          2       0.40      1.00      0.57         8\n",
      "\n",
      "avg / total       0.51      0.66      0.55        38\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/lxy/.local/lib/python3.6/site-packages/sklearn/metrics/classification.py:1135: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples.\n",
      "  'precision', 'predicted', average, warn_for)\n"
     ]
    }
   ],
   "source": [
    "y_predicted = predict(X_test, model)\n",
    "print(y_predicted)\n",
    "accuracy = np.mean(y_predicted == y_test) * 100\n",
    "print(\"在测试集上的准确率 {:.1f}%\".format(accuracy))\n",
    "from sklearn.metrics import classification_report\n",
    "print(classification_report(y_test, y_predicted))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "hidden": true
   },
   "source": [
    "在测试集上的准确率 65.8%，比完全随机预测 50% 好一点点"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  },
  "toc": {
   "colors": {
    "hover_highlight": "#DAA520",
    "navigate_num": "#000000",
    "navigate_text": "#333333",
    "running_highlight": "#FF0000",
    "selected_highlight": "#FFD700",
    "sidebar_border": "#EEEEEE",
    "wrapper_background": "#FFFFFF"
   },
   "moveMenuLeft": true,
   "nav_menu": {
    "height": "12px",
    "width": "252px"
   },
   "navigate_menu": true,
   "number_sections": true,
   "sideBar": true,
   "threshold": 4,
   "toc_cell": false,
   "toc_section_display": "block",
   "toc_window_display": false,
   "widenNotebook": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
